# coding: utf-8
# 二叉搜索树(BST)实现(val+left+right+parent)，必须掌握，包括前中后序遍历
# 二叉搜索树的最近公共祖先问题

# 二叉搜索树是具有有以下性质的二叉树：
#   （1）若左子树不为空，则左子树上所有节点的值均小于或等于它的根节点的值。
#   （2）若右子树不为空，则右子树上所有节点的值均大于或等于它的根节点的值
#    (3)BST不能像链表有一个表头节点，必须讨论root和非root
# https://www.jianshu.com/p/a0a354bc7514

import networkx as nx
import matplotlib.pyplot as plt

# 节点类
class TreeNode(object):
    def __init__(self, data, left=None, right=None, parent=None):
        #值
        self.data=data
        # 左子树
        self.left=left
        # 右子树
        self.right=right
        # 父节点
        self.parent=parent
    
class BST(object):
    def __init__(self):
        self.root=None
        #节点数
        self.size=0

    # 从指定的节点开始查找
    def findFrom(self,key,node):
        if not node:
            return None
        elif node.data==key:
            return node
        elif node.data>key:
            return self.findFrom(key, node.left)
        else:
            return self.findFrom(key, node.right)

    # 从root开始查找
    def find(self,key):
        if self.root:
            result = self.findFrom(key, self.root)
        else:
            result = None
        return result

    def printFrom(self,node):
        if node:
            self.printFrom(node.left)
            print(node.data)
            self.printFrom(node.right)

    #中序遍历
    def printTree(self):
        if self.size<=0:
            print("empty tree")
        else:
            self.printFrom(self.root)

    #经典算法--树的高度！！
    def height(self, node):
        if not node:
            return 0
        leftH=self.height(node.left)
        rightH=self.height(node.right)
        return max(leftH, rightH) + 1

    def insert(self,key):
        newNode=TreeNode(key)
        #插入需要讨论是否有根节点
        if not self.root:
            self.root=newNode
            self.size+=1
        else:
            curNode=self.root
            #查找插入的位置
            while True:
                #已经存在，结束
                if curNode.data==key:
                    break
                #左子树
                elif curNode.data>key:
                    if curNode.left:
                        curNode=curNode.left
                    else:
                        curNode.left=newNode
                        newNode.parent=curNode
                        self.size+=1
                        break
                else:
                    if curNode.right:
                        curNode = curNode.right
                    else:
                        curNode.right=newNode
                        newNode.parent=curNode
                        self.size+=1
                        break   
        return self

    # 使用networkx创建graph
    # TODO 熟悉networkx和matplotlib
    def create_graph(self, G, node, pos={}, x=0, y=0, layer=1):
        pos[node.data] = (x, y)
        if node.left:
            G.add_edge(node.data, node.left.data)
            l_x, l_y = x - (0.5) ** layer, y - 1
            l_layer = layer + 1
            self.create_graph(G, node.left, x=l_x, y=l_y, pos=pos, layer=l_layer)
        if node.right:
            G.add_edge(node.data, node.right.data)
            r_x, r_y = x + (0.5) ** layer, y - 1
            r_layer = layer + 1
            self.create_graph(G, node.right, x=r_x, y=r_y, pos=pos, layer=r_layer)
        return (G, pos)
    
    # 使用matplotlib的pyplot画图
    # https://zhuanlan.zhihu.com/p/35574577
    def draw(self, node):   # 以某个节点为根画图
        #有向图 direction graph
        graph = nx.DiGraph()
        graph, pos = self.create_graph(graph, node)
        fig, ax = plt.subplots(figsize=(10, 10))  # 比例可以根据树的深度适当调节
        nx.draw_networkx(graph, pos, ax=ax, node_size=300, node_color='y')
        plt.show()
    
    # 返回(是否平衡, 树高)
    # python中返回元组和拆分返回值元组都不需要小括号
    def isBalancedFromNode(self, node):
        if not node:
            return True, 0;
        leftB, leftH=self.isBalancedFromNode(node.left)
        rightB, rightH=self.isBalancedFromNode(node.right)
        h=max(leftH, rightH)+1
        ba= leftB and rightB and abs(leftH-rightH)<=1
        return ba, h
    
    #判断BST是否平衡
    def isBalanced(self):
        if self.size<=1:
            return True
        else:
            ba,h=self.isBalancedFromNode(self.root)
            print(h)
            return ba

bst=BST()
#bst.insert(1).insert(2).insert(3).insert(4).insert(5)
bst.insert(3).insert(2).insert(4).insert(6).insert(5)
bst.printTree()
print("tree height", bst.height(bst.root))
bst.draw(bst.root)
print("isBalanced", bst.isBalanced())

